{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "provenance": [],
   "gpuType": "T4",
   "machine_shape": "hm"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  },
  "accelerator": "GPU",
  "gpuClass": "standard"
 },
 "cells": [
  {
   "cell_type": "code",
   "source": [
    "!nvidia-smi"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "uHJROIGn3i_Z",
    "outputId": "764afe17-db79-4832-e93b-5e59728de378"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Sat May  6 10:14:14 2023       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
      "| N/A   58C    P0    29W /  70W |  13785MiB / 15360MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "!git clone https://huggingface.co/openlm-research/open_llama_7b_preview_300bt"
   ],
   "metadata": {
    "id": "PlYcNqrMqcgy"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KQUxwEIPmF36"
   },
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -qqq transformers==4.28.1 --progress-bar off\n",
    "%pip install -qqq bitsandbytes==0.38.1 --progress-bar off\n",
    "%pip install -qqq accelerate==0.18.0 --progress-bar off\n",
    "%pip install -qqq sentencepiece==0.1.99 --progress-bar off"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "import textwrap\n",
    "\n",
    "import torch\n",
    "from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer"
   ],
   "metadata": {
    "id": "dIaeE5sKplgx"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "def print_response(response: str):\n",
    "    \"\"\"Print the response re-flowed so that no line is wider than 110 characters.\"\"\"\n",
    "    wrapped_lines = textwrap.wrap(response, width=110)\n",
    "    print(\"\\n\".join(wrapped_lines))"
   ],
   "metadata": {
    "id": "R7I0fMNXxc7k"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "BOS_TOKEN_ID = 1\n",
    "EOS_TOKEN_ID = 2\n",
    "MAX_TOKENS = 1024"
   ],
   "metadata": {
    "id": "HtRfNE-RrDmT"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "MODEL_NAME = \"/content/open_llama_7b_preview_300bt/open_llama_7b_preview_300bt_transformers_weights\"\n",
    "\n",
    "# add_eos_token is deliberately left at its default (off): with it enabled every\n",
    "# prompt is encoded with a trailing end-of-sequence token, so the model is asked\n",
    "# to continue text that has already been marked as finished.\n",
    "tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "# Half-precision weights, placed on the available devices by accelerate.\n",
    "model = LlamaForCausalLM.from_pretrained(\n",
    "    MODEL_NAME, local_files_only=True, torch_dtype=torch.float16, device_map=\"auto\"\n",
    ")"
   ],
   "metadata": {
    "id": "a_8RthnX9SPG"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "tokenizer.bos_token_id = BOS_TOKEN_ID"
   ],
   "metadata": {
    "id": "UcxAGhaTVO25"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Single prompt"
   ],
   "metadata": {
    "id": "oBE_HnVDKJqJ"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "prompt = \"The world's highest building is\""
   ],
   "metadata": {
    "id": "mAzCHuY9_1sn"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "# temperature only takes effect when sampling is enabled; without do_sample=True\n",
    "# generate() decodes greedily and the temperature setting is ignored.\n",
    "generation_config = GenerationConfig(max_new_tokens=256, do_sample=True, temperature=0.5)"
   ],
   "metadata": {
    "id": "CLHPIWub__MW"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)"
   ],
   "metadata": {
    "id": "jmtslKI1ACcX"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "inputs"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "YzAmXhqok9U-",
    "outputId": "ab8f39fd-18b4-4ca4-ec7d-b38c49ce31ed"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[    1,   347,   925, 31889, 31842,  4454,  2203,   322,     0]],\n",
       "       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}"
      ]
     },
     "metadata": {},
     "execution_count": 11
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "%%time\n",
    "\n",
    "with torch.inference_mode():\n",
    "    tokens = model.generate(**inputs, generation_config=generation_config)"
   ],
   "metadata": {
    "id": "T4p2rFsCAHJx",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "outputId": "a1487161-a49f-442a-fb64-97fafd886b15"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "CPU times: user 18.2 s, sys: 306 ms, total: 18.5 s\n",
      "Wall time: 21.5 s\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "tokens"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ceEPtizqlOJ3",
    "outputId": "a87dcc6d-cb7e-42d6-b9ce-3e2076fd65f5"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "tensor([[    1,   347,   925, 31889, 31842,  4454,  2203,   322,     0, 31856,\n",
       "           347,   925, 31889, 31842,  8522,   384,  2203,   322,   266,  4646,\n",
       "         31892, 21527, 21880, 31856,   347,   925, 31889, 31842,  8522,   384,\n",
       "          2203,   322,   266,  4646, 31892, 21527, 21880, 31856,   347,   925,\n",
       "         31889, 31842,  8522,   384,  2203,   322,   266,  4646, 31892, 21527,\n",
       "         21880, 31856,   347,   925, 31889, 31842,  8522,   384,  2203,   322,\n",
       "           266,  4646, 31892, 21527, 21880, 31856,   347,   925, 31889, 31842,\n",
       "          8522,   384,  2203,   322,   266,  4646, 31892, 21527, 21880, 31856,\n",
       "           347,   925, 31889, 31842,  8522,   384,  2203,   322,   266,  4646,\n",
       "         31892, 21527, 21880, 31856,   347,   925, 31889, 31842,  8522,   384,\n",
       "          2203,   322,   266,  4646, 31892, 21527, 21880, 31856,   347,   925,\n",
       "         31889, 31842,  8522,   384,  2203,   322,   266,  4646, 31892, 21527,\n",
       "         21880, 31856,   347,   925, 31889, 31842,  8522,   384,  2203,   322,\n",
       "           266,  4646, 31892, 21527, 21880, 31856,   347,   925, 31889, 31842,\n",
       "          8522,   384,  2203,   322,   266,  4646, 31892, 21527, 21880, 31856,\n",
       "           347,   925, 31889, 31842,  8522,   384,  2203,   322,   266,  4646,\n",
       "         31892, 21527, 21880, 31856,   347,   925, 31889, 31842,  8522,   384,\n",
       "          2203,   322,   266,  4646, 31892, 21527, 21880, 31856,   347,   925,\n",
       "         31889, 31842,  8522,   384,  2203,   322,   266,  4646, 31892, 21527,\n",
       "         21880, 31856,   347,   925, 31889, 31842,  8522,   384,  2203,   322,\n",
       "           266,  4646, 31892, 21527, 21880, 31856,   347,   925, 31889, 31842,\n",
       "          8522,   384,  2203,   322,   266,  4646, 31892, 21527, 21880, 31856,\n",
       "           347,   925, 31889, 31842,  8522,   384,  2203,   322,   266,  4646,\n",
       "         31892, 21527, 21880, 31856,   347,   925, 31889, 31842,  8522,   384,\n",
       "          2203,   322,   266,  4646, 31892, 21527, 21880, 31856,   347,   925,\n",
       "         31889, 31842,  8522,   384,  2203,   322,   266,  4646, 31892, 21527,\n",
       "         21880, 31856,   347,   925, 31889]], device='cuda:0')"
      ]
     },
     "metadata": {},
     "execution_count": 13
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "completion = tokenizer.decode(tokens[0], skip_special_tokens=True)\n",
    "print_response(completion)"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "LOYLu8ybAN0g",
    "outputId": "bcc2a9bb-29fa-4e6e-87ec-6469002da31b"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "The world's highest building is. The world's tallest building is the Burj Khalifa. The world's tallest\n",
      "building is the Burj Khalifa. The world's tallest building is the Burj Khalifa. The world's tallest building\n",
      "is the Burj Khalifa. The world's tallest building is the Burj Khalifa. The world's tallest building is the\n",
      "Burj Khalifa. The world's tallest building is the Burj Khalifa. The world's tallest building is the Burj\n",
      "Khalifa. The world's tallest building is the Burj Khalifa. The world's tallest building is the Burj Khalifa.\n",
      "The world's tallest building is the Burj Khalifa. The world's tallest building is the Burj Khalifa. The\n",
      "world's tallest building is the Burj Khalifa. The world's tallest building is the Burj Khalifa. The world's\n",
      "tallest building is the Burj Khalifa. The world's tallest building is the Burj Khalifa. The world's tallest\n",
      "building is the Burj Khalifa. The world's tallest building is the Burj Khalifa. The world'\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Prompting with sampling"
   ],
   "metadata": {
    "id": "NtvOWBP2KMzU"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "def top_k_sampling(logits, k=10):\n",
    "    \"\"\"Sample one token id from the k highest-scoring entries of logits.\n",
    "\n",
    "    Args:\n",
    "        logits: 1-D tensor of unnormalised scores, one per candidate token.\n",
    "        k: Number of top candidates to sample from; capped at the size of logits.\n",
    "\n",
    "    Returns:\n",
    "        The sampled index into logits as a Python int.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If k is smaller than 1.\n",
    "    \"\"\"\n",
    "    if k < 1:\n",
    "        raise ValueError(f\"k must be >= 1, got {k}\")\n",
    "    # torch.topk raises when k exceeds the dimension size, so cap it.\n",
    "    k = min(k, logits.size(-1))\n",
    "    top_values, top_indices = torch.topk(logits, k)\n",
    "    # Normalise in float32 so half-precision logits do not lose precision.\n",
    "    probabilities = torch.softmax(top_values.float(), dim=-1)\n",
    "    choice = torch.multinomial(probabilities, num_samples=1)\n",
    "    return int(top_indices[choice].item())\n",
    "\n",
    "\n",
    "def process_chat(\n",
    "    model: LlamaForCausalLM,\n",
    "    tokenizer: LlamaTokenizer,\n",
    "    prompt: str,\n",
    "    max_new_tokens: int = 256,\n",
    "):\n",
    "    \"\"\"Continue the prompt by sampling one token at a time with top-k sampling.\n",
    "\n",
    "    Args:\n",
    "        model: Causal language model used for the forward passes.\n",
    "        tokenizer: Tokenizer used to encode the prompt and decode the result.\n",
    "        prompt: Text to continue.\n",
    "        max_new_tokens: Upper bound on the number of sampled tokens.\n",
    "\n",
    "    Returns:\n",
    "        The decoded prompt followed by the sampled continuation, with special\n",
    "        tokens skipped.\n",
    "    \"\"\"\n",
    "    input_ids = list(tokenizer(prompt).input_ids)\n",
    "\n",
    "    # A tokenizer created with add_eos_token=True ends the encoded prompt with an\n",
    "    # end-of-sequence token; drop it so the model continues the prompt itself.\n",
    "    if (\n",
    "        getattr(tokenizer, \"add_eos_token\", False)\n",
    "        and input_ids\n",
    "        and input_ids[-1] == tokenizer.eos_token_id\n",
    "    ):\n",
    "        input_ids.pop()\n",
    "\n",
    "    # The returned text always starts with the full (untruncated) prompt.\n",
    "    output_token_ids = list(input_ids)\n",
    "\n",
    "    # Keep only the most recent prompt tokens that fit the MAX_TOKENS budget.\n",
    "    # Clamp to at least 1: a zero length would make the slice keep the whole\n",
    "    # prompt and a negative one would cut tokens from the wrong end.\n",
    "    max_src_len = max(MAX_TOKENS - max_new_tokens - 8, 1)\n",
    "    next_input = input_ids[-max_src_len:]\n",
    "    past_key_values = None\n",
    "\n",
    "    with torch.inference_mode():\n",
    "        for _ in range(max_new_tokens):\n",
    "            out = model(\n",
    "                # Follow the model's own device instead of the global DEVICE.\n",
    "                input_ids=torch.as_tensor([next_input], device=model.device),\n",
    "                use_cache=True,\n",
    "                past_key_values=past_key_values,\n",
    "            )\n",
    "            past_key_values = out.past_key_values\n",
    "\n",
    "            token_id = top_k_sampling(out.logits[0][-1])\n",
    "            output_token_ids.append(token_id)\n",
    "\n",
    "            if token_id == EOS_TOKEN_ID:\n",
    "                break\n",
    "\n",
    "            # With the key/value cache only the newest token is fed next.\n",
    "            next_input = [token_id]\n",
    "\n",
    "    return tokenizer.decode(output_token_ids, skip_special_tokens=True)"
   ],
   "metadata": {
    "id": "sx_zG4GvJ7Ze"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "%%time\n",
    "prompt = \"You're Michael G Scott from the office. What is your favorite phrase?\"\n",
    "response = process_chat(model, tokenizer, prompt)\n",
    "print_response(response)"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "UACwKuSkIEI_",
    "outputId": "a6ffa056-fdaf-4b07-f027-7886257a66db"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "You're Michael G Scott from the office. What is your favorite phrase?. I'm a bit too old to be doing this!\n",
      "What do you think of your character in the Office? Who is your favorite person you know who is also an actor?\n",
      "How do you get your inspiration? How do you feel about working with John Krasinski? What is the hardest thing\n",
      "about being in the show? What kind of a fan are you? What's your favorite thing about being an actor?\n",
      "CPU times: user 5.44 s, sys: 0 ns, total: 5.44 s\n",
      "Wall time: 5.46 s\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "prompt = \"The world's highest building is\"\n",
    "response = process_chat(model, tokenizer, prompt)\n",
    "print_response(response)"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "k6tcUkdT-H-q",
    "outputId": "feea72cd-bb93-490c-d0b0-3718f28e0807"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "The world's highest building is tower, which is 1,336.8 meters from the ground in 2007. A building 1,336.8\n",
      "meters high is called tallest building in history. The tallest building in the world is currently Burj Khalifa\n",
      "(Dubai). Burj Khalifa is 829.8 meters tall. Burj Khalifa is a 77-story skyscraper. This tall building is the\n",
      "28th tallest building in the world and the 10th highest structure in the world. Burj Khalifa is a mixed use\n",
      "building. In addition to office space, there are apartments, hotels, shopping malls in the building. In\n",
      "addition to being the highest building in the world, the tower Burj Khalifa was also the first building on the\n",
      "planet to reach the height of 800 meters. This is a record that was achieved in 2010. Burj Khalifa is owned by\n",
      "the real estate developer Emaar Properties. Burj Khalifa was developed in a record time, the building was\n",
      "opened at 2010, but was first opened at the end of December 2010 and officially opened on January 4\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "prompt = \"The best way to invest $10,000 is\"\n",
    "response = process_chat(model, tokenizer, prompt)\n",
    "print_response(response)"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ByY-PuTiGtRy",
    "outputId": "835e55da-831a-41d3-e064-b8c2734adf53"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "The best way to invest $10,000 is The best way to invest $10,000 is to start by saving up to $5000 and then\n",
      "buy stock that’s already cheap. This is a more conservative way of investing and you can buy a good stock for\n",
      "a cheap price. I think you’re right, but the way you’re describing it is more like buying a stock for 50 cents\n",
      "a share. If you look at the way that I describe it, it’s not a stock that’s a penny a share. It’s a stock\n",
      "that’s a penny a share, so it’s a penny on a dollar. If you look at the way that I describe it, it’s a stock\n",
      "that’s a penny on a dollar. It’s a penny on a dollar, so it’s a penny on a dollar. But if you look at it, it’s\n",
      "a penny on a penny. I think it’s a penny on a penny. But the way you describe it is a penny on a penny. But if\n",
      "you look at it, it’s a penny on a penny. I think it’s a penny on the dollar. But the way you describe it is a\n",
      "penny on the penny. But\n"
     ]
    }
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "prompt = \"The best make and model v8 manual gearbox car is\"\n",
    "response = process_chat(model, tokenizer, prompt)\n",
    "print_response(response)"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "e8lclPl1G60Y",
    "outputId": "2fa65dc0-27c7-4fdf-c66e-fa7e817c7fe0"
   },
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "The best make and model v8 manual gearbox car is 1968-1969 gmc sierra with 4spd manual or gm overdrive auto.\n",
      "If you want to know how to install a manual gearbox car, please do not hesitate to contact us. If you are\n",
      "wondering how to install a manual gearbox car, please contact us at:. If you want to know how to install a\n",
      "manual transaxle gearbox car, please read the following.\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## References\n",
    "\n",
    "- https://github.com/riversun/open_llama_7b_hands_on\n",
    "- https://news.ycombinator.com/item?id=35798888\n",
    "- https://github.com/openlm-research/open_llama\n",
    "- https://huggingface.co/openlm-research/open_llama_7b_preview_300bt"
   ],
   "metadata": {
    "id": "56SaOYzppR1m"
   }
  }
 ]
}